% Copyright 2009 Neurosciences Research Foundation, Incorporated
% [cost_measure spikeout mfr all_target_mfr] = spnet( conn_data_in, draw,
%       testing, stop_sec, optional_pattern, all_target_mfr)
% testing==true means don't reinit weights; use last ones.
function [cost_measure spikeout mfr all_target_mfr] = spnet( conn_data_in, draw, testing, stop_sec, optional_pattern, all_target_mfr)
% SPNET  Simulate a three-group (IN -> A <-> AI) conductance-based spiking
% network with axonal delays, short-term synaptic plasticity and STDP, and
% score the pattern-evoked firing rates of group A against a target profile.
%
% Inputs (all optional; missing or [] means "use the default"):
%   conn_data_in     - either a column vector holding N_connections*conn_vars
%                      connection parameters followed by N_groups*group_vars
%                      group parameters, or an N_connections x conn_vars
%                      matrix of connection parameters only. Connection
%                      parameters overwrite columns search_params of the
%                      built-in defaults. Empty -> all defaults.
%   draw             - true to print parameters and plot diagnostics
%                      (default false).
%   testing          - true to reuse the weights left in the globals by the
%                      previous call instead of rebuilding them, and to skip
%                      learning (default false).
%   stop_sec         - number of simulated seconds (default 1).
%   optional_pattern - npatterns x length(IN) input patterns. A scalar
%                      (e.g. false) or [] means "generate with make_patterns".
%   all_target_mfr   - ignored on input; always recomputed and returned.
%
% Outputs:
%   cost_measure   - Euclidean distance between the measured per-pattern
%                    rates of neurons max(IN)+1:Ne and all_target_mfr
%                    (1e10 when no spikes were recorded).
%   spikeout       - [time_ms neuron] rows for every spike from
%                    start_save_sec on.
%   mfr            - npatterns x N mean rates (Hz) when spikes were recorded;
%                    otherwise the 1 x N running rate (zeros unless draw).
%   all_target_mfr - npatterns x length(IN) target rate profiles.

global s sd syntype D N syndata
global Ne IN A AI
global vhist uhist ehist ampahist nmdahist gabaahist gababhist itothist
global w r

% Missing or empty optional arguments fall back to their defaults.
if ~exist('draw', 'var') || isempty(draw)
    draw = false;
end
if ~exist('testing', 'var') || isempty(testing)
    testing = false;
end
if ~exist('stop_sec', 'var') || isempty(stop_sec)
    stop_sec = 1;
end
have_conn_data = exist('conn_data_in', 'var') && ~isempty(conn_data_in);

N_connections = 6;
conn_vars = 3;
N_groups = 3;
group_vars = 7;

% excitatory neurons   % inhibitory neurons      % total number
Ne = 50*2;              Ni = 25*2;                 N = Ne + Ni;

% Group layout: IN and A are excitatory (1:Ne), AI is inhibitory.
IN = 1:Ne/2;
A  = max(IN)+1:Ne;
AI = max(A)+1:N;
groups = {IN, A, AI};

% Default group parameters, one row per group (IN, A, AI). Columns:
%   background ampa, background gaba, nmda gain, gabab gain,
%   facilitation tau (ms), depression tau (ms), release fraction U.
% Defined unconditionally so a matrix-form conn_data_in (which carries no
% group parameters) still has group settings.
group_data_in = [0 0 0.5 0.2 1000 800 .5;   % IN
                 0 0 0.3 0.1 1000 800 .5;   % A
                 0 0 0.5 0.2   20 700 .2];  % AI

if have_conn_data && size(conn_data_in, 2) == 1
    % Column-vector form: connection parameters, then group parameters.
    group_data_in = reshape(conn_data_in(N_connections*conn_vars+1:end), N_groups, group_vars);
    conn_data_in = reshape(conn_data_in(1:N_connections*conn_vars), N_connections, conn_vars);
    if draw
        group_name = {'IN', 'A', 'AI'};
        conn_name = {'IN->A', 'IN->AI', 'A->AI', 'A->A', 'AI->A', 'AI->AI'};

        disp('          background ampa, gaba,   nmdagain, gababgain,      f,     d,       u');
        for i = 1:N_groups
            fprintf('%10s ', group_name{i});
            fprintf('%10.3f ', group_data_in(i,:))
            fprintf('\n');
        end
        disp('          std        wt range,      radius')
        for i = 1:N_connections
            fprintf('%10s ', conn_name{i});
            fprintf('%10.3f ', conn_data_in(i,:))
            fprintf('\n');
        end
    end
end
start_save_sec = 0;

% Background (noise) conductance amplitudes per neuron.
gampa_background = zeros(N,1);
ggabaa_background = zeros(N,1);
for gi = 1:N_groups
    gampa_background(groups{gi}) = group_data_in(gi,1);
    ggabaa_background(groups{gi}) = group_data_in(gi,2);
end

%% input patterns and target rates
epoch_ms = 1000;
npatterns = 4;
isi_ms = floor(epoch_ms/npatterns);   % each pattern is shown for isi_ms

if exist('optional_pattern', 'var') && numel(optional_pattern) > 1
    r = optional_pattern;
    if size(r,1) < npatterns || size(r,2) ~= length(IN)
        error('spnet:pattern', 'optional_pattern must be at least %d x %d.', npatterns, length(IN));
    end
else
    r = make_patterns(npatterns, length(IN));
end

% Target rate profile per pattern: a unit impulse at the position where the
% pattern's response to a cosine kernel peaks, spread by the rectified
% kernel and scaled by 40.
w = cos(-1.5*pi:1.5*pi/(Ne/4):1.5*pi);
positive_w = w;
positive_w(w<0) = 0;
all_target_mfr = zeros(npatterns, length(IN));
for i = 1:npatterns
    junk = zeros(size(IN));
    [z maxpoint] = max(conv2(r(i,:), w, 'same'));
    junk(maxpoint) = 1;
    all_target_mfr(i,:) = 40*conv2(junk, positive_w, 'same');
end

patterns = cell(1, npatterns);
for i = 1:npatterns
    patterns{i} = r(i,:);
end

search_params = [1 3 4];   % conn_data_in columns map to [std initrangew radius]
learning = true;
spikeout = [];

% spnet.m: Spiking network with axonal conduction delays and STDP
% Created by Eugene M.Izhikevich.                February 3, 2004
% Modified to allow arbitrary delay distributions.  April 16,2008
% Rewritten by Jeffrey L. McKinstry, November, 2009 to allow anatomy.

D = 5;                  % maximal conduction delay (ms)

% Neuron parameters from Izhikevich & Edelman (2008): RS cells for the
% excitatory population (1:Ne), FS cells for the inhibitory one.
vr = [-60*ones(Ne,1);  -55*ones(Ni,1)];
vt = [-50*ones(Ne,1);  -40*ones(Ni,1)];
k  = [  3*ones(Ne,1);    1*ones(Ni,1)];
C  = [ 80*ones(Ne,1);   20*ones(Ni,1)];
a  = [0.01*ones(Ne,1); 0.15*ones(Ni,1)];
b  = [  5*ones(Ne,1);    8*ones(Ni,1)];
c  = [-60*ones(Ne,1);  -55*ones(Ni,1)];
d  = [ 10*ones(Ne,1);  200*ones(Ni,1)];
vp = [ 50*ones(Ne,1);   25*ones(Ni,1)];

% IN neurons get their own parameter set (b = 0: assumed tonic mode).
vr(IN) = -60;
vt(IN) = -50;
k(IN)  = 1.6;
C(IN)  = 200;
a(IN)  = 0.1;
b(IN)  = 0;
c(IN)  = -60;
d(IN)  = 10;
vp(IN) = 40;

%% connections
if ~testing
    s = cell(D,1);
    sd = cell(D,1);
    syntype = cell(D,1);
    for i = 1:D
        s{i} = sparse(N,N);
        syntype{i} = sparse(N,N);
        sd{i} = sparse(N,N);
    end
    syndata = [];

    % Default connection parameters, one row per projection. Columns:
    %   std, initoffsetw, initrangew, radius, delaymin, delayrange,
    %   lrate, minw, maxrangew.
    % std is not used by the 'cos' / 'cossurround' arbors.
    radius_ai = length(AI)/2;
    radius_a = length(A)/2;
    not_used = pi;
    conn_data = [0.5*radius_ai 0 5*5 radius_ai 3 3 .05 0 20;   % IN->A   rand
                 0.5*radius_ai 0   0 radius_ai 1 2 .05 0 20;   % IN->AI  gauss
                 0.5*2         0  80 2         1 2 .05 0 20;   % A->AI   gauss
                 not_used      0   5 radius_a  1 2 0.0 0 10;   % A->A    cos
                 not_used      0 400 radius_a  1 1 0.0 0 20;   % AI->A   cossurround
                 not_used      0 100 radius_a  1 1 0.0 0 20];  % AI->AI  cossurround
    if have_conn_data
        % merge the supplied parameters with the defaults.
        conn_data(:,search_params) = conn_data_in(1:N_connections,:);
    end

    % The call order fixes the syndata / syntype indices, so keep it stable.
    pre_group  = {IN, IN, A, A, AI, AI};
    post_group = {A, AI, AI, A, A, AI};
    arbor      = {'rand', 'gauss', 'gauss', 'cos', 'cossurround', 'cossurround'};
    for i = 1:N_connections
        conn = conn_data(i,:);
        connect(pre_group{i}, post_group{i}, arbor{i}, 'std', conn(1), ...
            'initoffsetw', conn(2), 'initrangew', conn(3), 'radius', conn(4), ...
            'delaymin', conn(5), 'delayrange', conn(6), 'lrate', conn(7), ...
            'minw', conn(8), 'maxrangew', conn(9));
    end
end

mfr = zeros(1,N);

if learning && ~testing
    % Exponential STDP window: LTP side (reversed), then LTD side.
    ltp_decay = exp(-1/20);
    ltpcurve = .1*cumprod([1 ltp_decay*ones(1,49)]);
    ltd_decay = exp(-1/20);
    ltdcurve = -.12*cumprod([1 ltd_decay*ones(1,49)]);
    stdpcurve = [fliplr(ltpcurve) ltdcurve 0];
end

if draw
    vhist = zeros(N,1000);
    uhist = zeros(N,1000);
    ihist = zeros(N,1000);
    ehist = zeros(N,1000);
    ampahist = zeros(N,1000);
    nmdahist = zeros(N,1000);
    gabaahist = zeros(N,1000);
    gababhist = zeros(N,1000);
    itothist = zeros(N,1000);
end

%% model conductances (per-ms decay factors)
tau_ampa  = ones(N,1)*exp(-1/5);
tau_nmda  = ones(N,1)*exp(-1/150);
tau_gabaa = ones(N,1)*exp(-1/6);
tau_gabab = ones(N,1)*exp(-1/150);

tau_ampa(A)  = exp(-1/100);
tau_gabaa(A) = exp(-1/15);
tau_gabab(A) = exp(-1/150);

E_AMPA  = 0.0;
E_NMDA  = 0.0;
E_GABAa = -70;
E_GABAb = -90;

gain_ampa  = 1*ones(N,1);
gain_nmda  = .5*ones(N,1);
gain_gabaa = 1*ones(N,1);
gain_gabab = 0*ones(N,1);

%% stsp parameters
tau_F1 = exp(-1/1000)*ones(N,1);
tau_D1 = exp(-1/800)*ones(N,1);
U = 0.5*ones(N,1);

for gi = 1:N_groups
    members = groups{gi};
    gain_nmda(members)  = group_data_in(gi,3);
    gain_gabab(members) = group_data_in(gi,4);
    tau_F1(members) = exp(-1/group_data_in(gi,5));
    tau_D1(members) = exp(-1/group_data_in(gi,6));
    U(members) = group_data_in(gi,7);
end

MAXSPIKESPERSEC = N/2*1000;
firings = zeros(MAXSPIKESPERSEC,2);
firehist = cell(N,1);
numspikes = 0;

% All delay matrices side by side (N x N*D), split into excitatory and
% inhibitory presynaptic rows for the per-step delivery below.
snew = full([s{:}]);
se = snew(1:max(A),:);
si = snew(max(A)+1:N,:);

mfr_tau = exp(-1/200);   % decay of the running rate used for plotting

%% simulation
for sec = 1:stop_sec
  for t = 1:1000                          % 1 ms steps

      % Reset state at the start of each pattern presentation.
      if mod(t,isi_ms)==1 || t==1
        Igabaa = zeros(N,1); Igabab = zeros(N,1);
        Iampa = zeros(N,1);  Inmda = zeros(N,1);
        Ie = zeros(N,D);     % delay line of pending excitatory input
        Ii = zeros(N,D);     % delay line of pending inhibitory input

        stsp = U;
        stf = zeros(N,1);
        stdep = zeros(N,1);

        v = -65*ones(N,1)+5*rand(N,1);
        u = 0.2*ones(N,1)+20*rand(N,1);
      end

    fired = find(v>=vp);                % indices of fired neurons
    efired = v'>=vp';
    efired = efired.*stsp';             % spikes weighted by release fraction
    ifired = efired(max(A)+1:N);
    efired = efired(1:max(A));

    v(fired) = c(fired);
    u(fired) = u(fired)+d(fired);

    stdep(fired) = stdep(fired)+stsp(fired);
    stf(fired) = stf(fired)+U(fired).*(1-U(fired)-stf(fired));

    % Deliver spikes to all targets, slotting each into its delay column.
    Ie = Ie + reshape(efired*se, N, D);
    Ii = Ii + reshape(ifired*si, N, D);

    Igabaa = Igabaa.*tau_gabaa + gain_gabaa.*Ii(:,1);
    Igabab = Igabab.*tau_gabab + gain_gabab.*Ii(:,1);
    Iampa = Iampa.*tau_ampa + gain_ampa.*Ie(:,1);
    Inmda = Inmda.*tau_nmda + gain_nmda.*Ie(:,1);

    % Voltage-dependent NMDA gate.
    xx = (-80-v)/60;
    xx = xx.*xx;
    NMDAgate = xx./(1+xx);

    gampa_rand  = abs(randn(N,1)).*gampa_background;
    ggabaa_rand = abs(randn(N,1)).*ggabaa_background;

    g = (Iampa + gampa_rand + NMDAgate.*Inmda + Igabaa + ggabaa_rand + Igabab);
    E = ((Iampa+gampa_rand)*E_AMPA + NMDAgate.*Inmda*E_NMDA + (Igabaa+ggabaa_rand)*E_GABAa + Igabab*E_GABAb);
    Isyn = v.*g - E;
    % IN neurons are driven directly by the current pattern.
    Isyn(IN) = -1500*patterns{min(npatterns, 1+floor((t-1)/isi_ms))}';

    % Two 0.5 ms Euler half-steps for numerical stability. Values below
    % -104 are reset to -103 (kept from the original model).
    for half = 1:2
        v = v + 0.5*(k.*(v-vr).*(v-vt) - u - Isyn)./C;
        v(v>50) = 50;
        v(v<-104) = -103;
        u = u + 0.5*a.*(b.*(v-vr) - u);
    end

    % Short term synaptic plasticity.
    stf = stf.*tau_F1;
    stdep = stdep.*tau_D1;
    stsp = (1-stdep).*(U+stf);

    if draw
        ihist(:,t) = Ii(:,1); ehist(:,t) = Ie(:,1);
        itothist(:,t) = Isyn;
        ampahist(:,t) = Iampa; nmdahist(:,t) = NMDAgate.*Inmda;
        gabaahist(:,t) = Igabaa; gababhist(:,t) = Igabab;
    end

    % Advance the delay lines by one ms.
    Ie = [Ie(:,2:D) zeros(N,1)];
    Ii = [Ii(:,2:D) zeros(N,1)];

    % save the spike data.
    firings(numspikes+1:numspikes+length(fired),:) = [t*ones(length(fired),1), fired];
    numspikes = numspikes + length(fired);

    if learning && ~testing
        for i = 1:length(fired)
            firehist{fired(i)}(end+1) = t;
        end
    end

    if draw
        vhist(:,t) = v;
        uhist(:,t) = u;

        mfr(fired) = mfr(fired) + (1-mfr_tau)*1000;
        mfr = mfr*mfr_tau;
        if mod(t,250)==0
            figure(3);
            bar(mfr)
            drawnow;
        end
    end
  end

  disp(['t=' num2str(sec)])

  if draw
      figure(1)
      plot(firings(:,1),firings(:,2),'.');
      axis([0 1000 0 N]); drawnow;

      figure(2)
      hist(s{1}(find(s{1})),50);

      figure(4)
      subplot(7,1,1)
      imagesc(vhist,[-70 -60]); colorbar;
      title('vhist')
      subplot(7,1,2)
      imagesc(ihist); colorbar;
      title('inhib input history');
      subplot(7,1,3)
      imagesc(ehist); colorbar;
      title('excit input history');
      subplot(7,1,4)
      imagesc(ampahist); colorbar;
      title('ampa input history');
      subplot(7,1,5)
      imagesc(nmdahist); colorbar;
      title('nmda input history');
      subplot(7,1,6)
      imagesc(gabaahist); colorbar;
      title('gabaa input history');
      subplot(7,1,7)
      imagesc(gababhist); colorbar;
      title('gabab input history');

      % Mean of the first 100 ms of each pattern presentation.
      figure(5)
      subplot(4,1,1)
      j = 1;
      for i = 1:isi_ms:999
          sum_temp(:,j) = mean(itothist(:,i:i+99),2);
          sum_tempi(:,j) = mean(ihist(:,i:i+99),2);
          sum_tempe(:,j) = mean(ehist(:,i:i+99),2);
          sum_tempv(:,j) = mean(vhist(:,i:i+99),2);
          j = j+1;
      end
      plot(sum_temp);
      size(sum_temp)
      title('Isyn')

      legend('1','2','3','4','5','6');
      subplot(4,1,2)
      plot(sum_tempe);
      title('ehist');
      size(sum_tempe)

      subplot(4,1,3)
      plot(sum_tempi);
      title('ihist');

      subplot(4,1,4)
      plot(sum_tempv);
      title('voltage hist');
  end

  % STDP from this second's spike times. sd in Izhikevich's spnet was not
  % used until after 1 second, so updating once per second is equivalent.
  if learning && ~testing
      for delay = 1:D
        [pre_i post_i] = find(s{delay});

        for j = 1:length(pre_i)
            type = syntype{delay}(pre_i(j),post_i(j));
            w = s{delay}(pre_i(j),post_i(j));

            if syndata(type).max < .00001
                scale = 1;
            else
                scale = w/syndata(type).max;
            end
            sd{delay}(pre_i(j),post_i(j)) = sd{delay}(pre_i(j),post_i(j)) + stdp(firehist{pre_i(j)},firehist{post_i(j)}, delay, stdpcurve,syndata(type).lrate, scale);

            s{delay}(pre_i(j),post_i(j)) = max(syndata(type).min, min(syndata(type).max, syndata(type).lrate*.01 + w + sd{delay}(pre_i(j),post_i(j))));
        end

        sd{delay} = 0.9*sd{delay};
      end
      % keep spikes near the seconds boundary for future STDP calculations.
      for i = 1:N
        firehist{i} = firehist{i}(find(firehist{i} > 950)) - 1000;
      end

      normalize_weights;

      wsum = s{1};
      for delay = 2:D
          wsum = wsum + s{delay};
      end
      w = full(wsum(IN,A));
      if draw
          visualize_som(w,5,10)
          drawnow
      end
  end

  if sec >= start_save_sec
      spikeout = [spikeout; [firings(1:numspikes,1)+1000*(sec-1) firings(1:numspikes,2)]];
  end

  numspikes = 0;
end

if draw
    figure(7)
    subplot(2,1,1)
    m = IN(8);
    plot([vhist(m,:)' ampahist(m,:)' gabaahist(m,:)' itothist(m,:)' uhist(m,:)'])
    legend('vhist(m,:)', 'ampahist(m,:)', 'gabaahist(m,:)', 'itothist(m,:)', 'u' )

    subplot(2,1,2)
    if ~isempty(spikeout)
        plot(spikeout(:,1),spikeout(:,2),'.');
    end
end

%% cost measure: bin the spike data by pattern.
if ~isempty(spikeout)
    S = sparse(spikeout(:,2), spikeout(:,1), 1, N, stop_sec*1000);

    t = start_save_sec*1000 + 1;
    count = zeros(1,npatterns);
    mfr = zeros(npatterns,N);
    for i = start_save_sec*1000:epoch_ms:stop_sec*1000-1
        for p = 1:npatterns
            % Rate over each presentation, skipping its first 20 ms.
            mfr(p,:) = mfr(p,:) + (1000/isi_ms)*full(sum(S(:, t+20:(t+isi_ms-1)), 2))';
            count(p) = count(p)+1;
            t = t+isi_ms;
        end
    end
    mfr = mfr./(repmat(count', 1, size(mfr,2)) + .000001);

    cost_measure = sqrt( sum(sum((mfr(:,max(IN)+1:Ne) - all_target_mfr ).^2 )))
else
    cost_measure = 1e10;
end

if draw
    figure(8)
    plot(mfr(:,A)','--')
    hold on; plot(all_target_mfr')
    hold off;

    wsum = s{1};
    for delay = 2:D
        wsum = wsum + s{delay};
    end
    w = full(wsum(IN,A));
    visualize_som(w,5,10)
end




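%
% connect(from_indices, to_indices, arbor_type, name, value, ...)
%  arbor_type: 'gauss', 'surround', 'cos', 'cossurround' or 'rand'
%  name/value parameters: 'std', 'initoffsetw', 'initrangew', 'radius',
%  'delaymin', 'delayrange', 'lrate', 'minw', 'maxrangew'.
%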
function connect( from_indices, to_indices, arbor_type, varargin)
% CONNECT  Add one topographic projection from_indices -> to_indices to the
% global delay-indexed weight matrices s{1..D}, recording the synapse type
% in syntype, initialising sd, and appending the projection to syndata.
%
% Each presynaptic neuron projects to 2*radius+1 consecutive target neurons
% (with wraparound) centred on its topographic position in to_indices.
%
% arbor_type: 'gauss', 'surround', 'cos', 'cossurround' (one kernel shared
%   by every presynaptic neuron) or 'rand' (a new random kernel per neuron).
% Name/value parameters (any order):
%   'std'         - Gaussian width ('gauss' and 'surround' only).
%   'initoffsetw' - offset added to every initial weight (syndata.minw).
%   'initrangew'  - the kernel is scaled to sum to this before the offset
%                   is added (syndata.maxw).
%   'radius'      - arbor half-width in target neurons; rounded, and clipped
%                   so an arbor never reaches the same target twice.
%   'delaymin'    - delay (ms) of the centre synapse.
%   'delayrange'  - delays run from delaymin at the centre up to
%                   delaymin+delayrange-1 at the edge; clipped to D.
%   'lrate'       - learning rate (syndata.lrate).
%   'minw'        - lower weight bound (syndata.min).
%   'maxrangew'   - upper weight bound is minw + maxrangew (syndata.max).

global s sd syntype N D syndata

L = length(syndata);
syndata(L+1).post = to_indices;
syndata(L+1).pre  = from_indices;

maxrangew = [];
for k = 1:2:size(varargin,2)
    switch lower(varargin{k})
      case {'std'}
        std = varargin{k+1};
      case 'initoffsetw'
        syndata(L+1).minw = varargin{k+1};
        minw = varargin{k+1};
      case 'initrangew'
        syndata(L+1).maxw = varargin{k+1};
        maxw = varargin{k+1};
      case 'radius'
        radius = round(varargin{k+1});
      case 'delaymin'
        delaymin = round(varargin{k+1});
      case 'delayrange'
        delayrange = round(varargin{k+1});
      case 'lrate'
        syndata(L+1).lrate = varargin{k+1};
      case 'minw'
        syndata(L+1).min = varargin{k+1};
      case 'maxrangew'
        maxrangew = varargin{k+1};
      otherwise
        disp(['Unknown param.' varargin{k}])
    end
end

% Parameters that depend on other parameters are resolved only after all
% pairs are parsed, so the name/value order does not matter.
if ~isempty(maxrangew)
    syndata(L+1).max = syndata(L+1).min + maxrangew;
end
if delaymin + delayrange-1 > D
    disp('delay > max delay in connect. Using max.');
    delayrange
    delaymin
    delayrange = D-delaymin + 1;
end

% Assume wraparound.
% clip radius to not touch target neurons more than once.
if radius > length(to_indices)/2-.5, radius = floor(length(to_indices)/2-.5); end

x = -round(radius):round(radius);       % offsets within one arbor
p = [to_indices to_indices to_indices]; % tiled so offsets can wrap
% Spread presynaptic neurons evenly over the target group.
step = length(to_indices)/length(from_indices);
rel_index = .5;

switch arbor_type
  case 'gauss'
    w = std*sqrt(2*pi)*gaussian_1d(std, x);
    w = minw + maxw*(1/sum(w))*w;
  case 'surround'
    w = std*sqrt(2*pi)*abs(x).*gaussian_1d(std, x);
    w = minw + maxw*(1/sum(w))*w;
  case 'cos'
    w = cos(-1.5*pi:1.5*pi/radius:1.5*pi);
    w(w<0) = 0;
    w = minw + maxw*(1/sum(w))*w;   % make them sum to maxw.
  case 'cossurround'
    w = cos(-1.5*pi:1.5*pi/radius:1.5*pi);
    w(w>0) = 0;
    w = -w;  % make these weights positive for the conductance-based model.
    w = minw + maxw*(1/sum(w))*w;   % make them sum to maxw.
  case 'rand'
    % drawn separately for every presynaptic neuron below
  otherwise
    error('connect:arbor', 'Unknown arbor type ''%s''.', arbor_type);
end

% Delay grows with distance from the arbor centre (same for every neuron).
delays = delaymin + floor(delayrange*(abs(x)/(round(radius)+.001)));

for i = from_indices
    if strcmp(arbor_type, 'rand')
        % each neuron gets different weights
        w = rand(1, 2*floor(radius)+1);
        w = minw + maxw*(1/sum(w))*w;   % make them sum to maxw.
    end

    my_center = round(rel_index*step);
    rel_index = rel_index + 1;
    myoffset = length(to_indices) + my_center;

    target_indices = p(myoffset + x);

    for j = 1:length(target_indices)
        s{delays(j)}(i,target_indices(j)) = w(j);
        syntype{delays(j)}(i,target_indices(j)) = length(syndata);
        sd{delays(j)}(i,target_indices(j)) = 0.0000001;
    end
end


function xa = gaussian_1d(sigma_p, x)
% GAUSSIAN_1D  Zero-mean normal density with standard deviation sigma_p,
% evaluated element-wise at x (same shape as x).
peak = 1/(sigma_p * sqrt(2*pi));   % normalising constant
twice_var = 2*sigma_p^2;
xa = peak * exp(-(x.^2) ./ twice_var);


%%
% make some oriented gabor patches.
function P = make_patterns(samples,n)
% MAKE_PATTERNS  Build `samples` input patterns of length n from oriented
% Gabor patches (one orientation per sample, random contrast). Each patch is
% split into ON (positive) and OFF (negated negative) halves, which are
% stacked, and each row of P is normalised to sum to 1. Each patch is also
% shown in the current figure.
% NOTE(review): assumes gaborfilter returns a pixels x pixels image, so each
% pattern uses 2*pixels^2 <= n entries; confirm against gaborfilter.

pixels = floor(sqrt(n/2));

P = zeros(samples,n);
for i = 1:samples
    cycles = 2;
    xphase = 0;
    yphase = 0;
    angle = i*pi/samples;
    contrast = .25+.75*rand;
    sd_x = 1;
    sd_y = 4;

    g = contrast * gaborfilter(pixels,sd_x*pi,sd_y*pi,angle,cycles,xphase,yphase);

    imagesc(g,[-1 1])
    colormap(gray);
    drawnow;

    on = g(:);
    off = -on;
    on(on<0) = 0;
    off(off<0) = 0;
    temp = [on; off];
    P(i,1:length(temp)) = temp';
end

% Normalise once, after every row is filled. Normalising inside the loop
% divided the still-empty rows by zero and filled them with NaN. All-zero
% rows are left at zero.
row_sum = sum(P,2);
row_sum(row_sum == 0) = 1;
P = P ./ repmat(row_sum, 1, size(P,2));

% Display all neuron weights as a 2-D grid of images.
% Assumes each neuron's weight vector (one column of w) reshapes into a
% pixelrows x pixelcols image.
%
% w(pixels,neurons) is the som weight matrix.
%
% Usage: visualize_som(w, pixelrows, pixelcols)

function visualize_som(w,pixelrows,pixelcols)
% VISUALIZE_SOM  Draw each column of w as a pixelrows x pixelcols grey-scale
% image in figure 10, all on a shared colour scale. Every neuron is shown
% (the previous square neuronrows x neuronrows grid dropped the last ones).
figure(10)

nneurons = size(w,2);
neuronrows = floor(sqrt(nneurons));
if neuronrows == 0
    return;   % nothing to draw
end
neuroncols = ceil(nneurons/neuronrows);   % rows*cols >= nneurons

contrast = [min(w(:)) max(w(:))];
if contrast(1) == contrast(2)
    % imagesc needs an increasing colour range.
    contrast = contrast + [-0.5 0.5];
end

for k = 1:nneurons
    subplot(neuronrows,neuroncols,k);
    imagesc(reshape(w(:,k), pixelrows, pixelcols), contrast);
    axis off
    colormap(gray);
end




function normalize_weights
% NORMALIZE_WEIGHTS  For every plastic projection in syndata (lrate > 1e-6),
% rescale the weights so that each postsynaptic neuron's total input from
% that projection, summed over all delays s{1..D}, is approximately
% syndata(i).maxw. Then add syndata(i).minw to each existing (nonzero)
% synapse. Non-synapses stay zero, so the sparse connectivity is preserved.

global s sd syntype N D syndata

for i = 1:length(syndata)
    if syndata(i).lrate <= 10^-6
        continue;   % fixed projections are left untouched
    end

    fromi = syndata(i).pre;
    toi = syndata(i).post;

    % Total incoming weight per target neuron, over all delays.
    total = zeros(1,length(toi));
    for delay = 1:D
        total = total + sum(s{delay}(fromi,toi),1);
    end
    scale = repmat(1./(total+.00001), length(fromi), 1);

    % now normalize s.t. total weights is preserved.
    for delay = 1:D
        block = s{delay}(fromi,toi);
        existing = block ~= 0;
        block = syndata(i).maxw * scale .* block;
        block(existing) = block(existing) + syndata(i).minw;
        s{delay}(fromi,toi) = block;
    end
end




